---
title: Diffusion Distance Based Loss
keywords: fastai
sidebar: home_sidebar
nb_path: "05c05 Diffusion Distance based Loss - Tests and Visualizations.ipynb"
---
{% raw %}
{% endraw %} {% raw %}

class DiffusionDistanceFlowEmbedder[source]

DiffusionDistanceFlowEmbedder(X, flows, labels, device=device(type='cpu')) :: FETrainer

{% endraw %} {% raw %}
{% endraw %} {% raw %}
from directed_graphs.datasets import directed_swiss_roll_uniform, plot_directed_3d

# Generate the directed swiss-roll dataset (points X, per-point flow, labels)
# and show it in 3D.
X, flow, labels = directed_swiss_roll_uniform(
    num_nodes=1000,
    num_spirals=2.5,
    radius=1,
    height=12,
    xtilt=0,
    ytilt=0,
)
plot_directed_3d(X, flow, labels, mask_prob=0.5)
{% endraw %} {% raw %}
# Move the dataset onto `device` as float32 tensors in a single step.
# torch.as_tensor accepts whatever directed_swiss_roll_uniform returned
# (presumably NumPy arrays) as well as tensors left over from a previous run
# of this cell. That makes the cell safe to re-execute; torch.tensor on an
# existing tensor warns and copies. A second `.float().to(device)` pass
# afterwards would be a no-op, so none is done.
X = torch.as_tensor(X, dtype=torch.float32, device=device)
flow = torch.as_tensor(flow, dtype=torch.float32, device=device)
{% endraw %} {% raw %}
# Build the diffusion-distance flow embedder from the points, their flow
# vectors and labels, placed on `device`.
BOBO_FET = DiffusionDistanceFlowEmbedder(
    X,
    flow,
    labels=labels,
    device=device,
)
{% endraw %} {% raw %}
# Run the trainer's fit loop. In the captured run below it reported 100
# iterations at roughly 5.3 s each (~9 minutes total).
BOBO_FET.fit()
  0%|          | 0/100 [00:00<?, ?it/s]
  1%|          | 1/100 [00:05<08:42,  5.28s/it]
  2%|▏         | 2/100 [00:10<08:39,  5.31s/it]
  3%|▎         | 3/100 [00:15<08:33,  5.30s/it]
  4%|▍         | 4/100 [00:21<08:27,  5.29s/it]
  5%|▌         | 5/100 [00:26<08:21,  5.28s/it]
  6%|▌         | 6/100 [00:31<08:21,  5.33s/it]
  7%|▋         | 7/100 [00:37<08:12,  5.30s/it]
  8%|▊         | 8/100 [00:42<08:07,  5.30s/it]
  9%|▉         | 9/100 [00:47<08:04,  5.32s/it]
 10%|█         | 10/100 [00:53<08:00,  5.34s/it]
 11%|█         | 11/100 [00:58<07:55,  5.34s/it]
 12%|█▏        | 12/100 [01:03<07:53,  5.39s/it]
 13%|█▎        | 13/100 [01:09<07:43,  5.33s/it]
 14%|█▍        | 14/100 [01:14<07:35,  5.30s/it]
 15%|█▌        | 15/100 [01:19<07:27,  5.27s/it]
 16%|█▌        | 16/100 [01:24<07:20,  5.25s/it]
 17%|█▋        | 17/100 [01:30<07:16,  5.26s/it]
 18%|█▊        | 18/100 [01:35<07:10,  5.25s/it]
 19%|█▉        | 19/100 [01:40<07:10,  5.31s/it]
 20%|██        | 20/100 [01:46<07:03,  5.30s/it]
 21%|██        | 21/100 [01:51<06:58,  5.30s/it]
 22%|██▏       | 22/100 [01:56<06:53,  5.30s/it]
 23%|██▎       | 23/100 [02:01<06:48,  5.30s/it]
 24%|██▍       | 24/100 [02:07<06:41,  5.29s/it]
 25%|██▌       | 25/100 [02:12<06:42,  5.36s/it]
 26%|██▌       | 26/100 [02:17<06:33,  5.32s/it]
 27%|██▋       | 27/100 [02:23<06:26,  5.30s/it]
 28%|██▊       | 28/100 [02:28<06:21,  5.29s/it]
 29%|██▉       | 29/100 [02:33<06:16,  5.30s/it]
 30%|███       | 30/100 [02:39<06:10,  5.29s/it]
 31%|███       | 31/100 [02:44<06:05,  5.30s/it]
 32%|███▏      | 32/100 [02:49<06:03,  5.35s/it]
 33%|███▎      | 33/100 [02:55<05:58,  5.35s/it]
 34%|███▍      | 34/100 [03:00<05:51,  5.33s/it]
 35%|███▌      | 35/100 [03:05<05:45,  5.31s/it]
 36%|███▌      | 36/100 [03:11<05:39,  5.30s/it]
 37%|███▋      | 37/100 [03:16<05:33,  5.30s/it]
 38%|███▊      | 38/100 [03:21<05:31,  5.35s/it]
 39%|███▉      | 39/100 [03:27<05:28,  5.39s/it]
 40%|████      | 40/100 [03:32<05:22,  5.37s/it]
 41%|████      | 41/100 [03:37<05:15,  5.34s/it]
 42%|████▏     | 42/100 [03:43<05:09,  5.34s/it]
 43%|████▎     | 43/100 [03:48<05:02,  5.30s/it]
 44%|████▍     | 44/100 [03:53<04:56,  5.29s/it]
 45%|████▌     | 45/100 [03:59<04:53,  5.34s/it]
 46%|████▌     | 46/100 [04:04<04:47,  5.32s/it]
 47%|████▋     | 47/100 [04:09<04:41,  5.31s/it]
 48%|████▊     | 48/100 [04:15<04:36,  5.31s/it]
 49%|████▉     | 49/100 [04:20<04:31,  5.32s/it]
 50%|█████     | 50/100 [04:25<04:25,  5.32s/it]
 51%|█████     | 51/100 [04:31<04:22,  5.37s/it]
 52%|█████▏    | 52/100 [04:36<04:16,  5.34s/it]
 53%|█████▎    | 53/100 [04:41<04:10,  5.32s/it]
 54%|█████▍    | 54/100 [04:47<04:04,  5.32s/it]
 55%|█████▌    | 55/100 [04:52<03:58,  5.31s/it]
 56%|█████▌    | 56/100 [04:57<03:53,  5.31s/it]
 57%|█████▋    | 57/100 [05:03<03:51,  5.38s/it]
 58%|█████▊    | 58/100 [05:08<03:44,  5.34s/it]
 59%|█████▉    | 59/100 [05:13<03:37,  5.32s/it]
 60%|██████    | 60/100 [05:18<03:31,  5.30s/it]
 61%|██████    | 61/100 [05:24<03:27,  5.31s/it]
 62%|██████▏   | 62/100 [05:29<03:26,  5.43s/it]
 63%|██████▎   | 63/100 [05:35<03:19,  5.39s/it]
 64%|██████▍   | 64/100 [05:40<03:14,  5.41s/it]
 65%|██████▌   | 65/100 [05:46<03:08,  5.38s/it]
 66%|██████▌   | 66/100 [05:51<03:01,  5.35s/it]
 67%|██████▋   | 67/100 [05:56<02:55,  5.33s/it]
 68%|██████▊   | 68/100 [06:01<02:49,  5.31s/it]
 69%|██████▉   | 69/100 [06:07<02:46,  5.36s/it]
 70%|███████   | 70/100 [06:12<02:40,  5.34s/it]
 71%|███████   | 71/100 [06:18<02:36,  5.39s/it]
 72%|███████▏  | 72/100 [06:23<02:30,  5.37s/it]
 73%|███████▎  | 73/100 [06:28<02:24,  5.34s/it]
 74%|███████▍  | 74/100 [06:34<02:18,  5.32s/it]
 88%|████████▊ | 88/100 [07:48<01:03,  5.32s/it]
 89%|████████▉ | 89/100 [07:53<00:58,  5.34s/it]
 90%|█████████ | 90/100 [07:59<00:53,  5.35s/it]
 91%|█████████ | 91/100 [08:04<00:48,  5.43s/it]
 92%|█████████▏| 92/100 [08:10<00:43,  5.41s/it]
 93%|█████████▎| 93/100 [08:15<00:37,  5.38s/it]
 94%|█████████▍| 94/100 [08:20<00:32,  5.35s/it]
 95%|█████████▌| 95/100 [08:25<00:26,  5.34s/it]
 96%|█████████▌| 96/100 [08:31<00:21,  5.33s/it]
 97%|█████████▋| 97/100 [08:36<00:16,  5.37s/it]
 98%|█████████▊| 98/100 [08:41<00:10,  5.35s/it]
 99%|█████████▉| 99/100 [08:47<00:05,  5.32s/it]
100%|██████████| 100/100 [08:52<00:00,  5.33s/it]
{% endraw %} {% raw %}
# Visualize the embedding produced by the fit above (plotting logic lives on
# the embedder/trainer class, not in this notebook).
BOBO_FET.visualize_embedding()
{% endraw %} {% raw %}
# Render the training animation via the trainer's own helper.
# NOTE(review): presumably built from the per-step frames under
# visualizations/<timestamp>/ — confirm against the implementation.
BOBO_FET.training_gif()
{% endraw %} {% raw %}
# FETrainer.visualize_loss passes each entry of `self.losses` straight to
# plt.plot. That raises "TypeError: can't convert cuda:0 device type tensor to
# numpy" when the recorded losses are CUDA tensors (see captured traceback).
# Work around it here by moving every recorded loss to host memory first.
# NOTE(review): the root fix belongs in FETrainer (store `loss.item()` when
# recording, or convert inside visualize_loss); drop this workaround then.
def _losses_to_host(values):
    """Return `values` with any tensors detached and moved to the CPU.

    Handles a single tensor, or a list/tuple whose elements may be tensors.
    Any other value is returned unchanged.
    """
    if torch.is_tensor(values):
        return values.detach().cpu()
    if isinstance(values, (list, tuple)):
        return [v.detach().cpu() if torch.is_tensor(v) else v for v in values]
    return values


# Mutate in place so the container type of BOBO_FET.losses is preserved.
for _key in list(BOBO_FET.losses):
    BOBO_FET.losses[_key] = _losses_to_host(BOBO_FET.losses[_key])
BOBO_FET.visualize_loss()
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [9], in <cell line: 1>()
----> 1 BOBO_FET.visualize_loss()

File /gpfs/loomis/project/krishnaswamy_smita/kjm76/directed_graphs/directed_graphs/flow_embedding_training_utils.py:133, in FETrainer.visualize_loss(self, loss_type)
    131 if loss_type == "all":
    132   for key in self.losses.keys():
--> 133     plt.plot(self.losses[key])
    134   plt.legend(self.losses.keys(), loc='upper right')
    135   plt.title("loss")

File /gpfs/loomis/project/krishnaswamy_smita/kjm76/conda_envs/flowembed/lib/python3.9/site-packages/matplotlib/pyplot.py:2769, in plot(scalex, scaley, data, *args, **kwargs)
   2767 @_copy_docstring_and_deprecators(Axes.plot)
   2768 def plot(*args, scalex=True, scaley=True, data=None, **kwargs):
-> 2769     return gca().plot(
   2770         *args, scalex=scalex, scaley=scaley,
   2771         **({"data": data} if data is not None else {}), **kwargs)

File /gpfs/loomis/project/krishnaswamy_smita/kjm76/conda_envs/flowembed/lib/python3.9/site-packages/matplotlib/axes/_axes.py:1632, in Axes.plot(self, scalex, scaley, data, *args, **kwargs)
   1390 """
   1391 Plot y versus x as lines and/or markers.
   1392 
   (...)
   1629 (``'green'``) or hex strings (``'#008000'``).
   1630 """
   1631 kwargs = cbook.normalize_kwargs(kwargs, mlines.Line2D)
-> 1632 lines = [*self._get_lines(*args, data=data, **kwargs)]
   1633 for line in lines:
   1634     self.add_line(line)

File /gpfs/loomis/project/krishnaswamy_smita/kjm76/conda_envs/flowembed/lib/python3.9/site-packages/matplotlib/axes/_base.py:312, in _process_plot_var_args.__call__(self, data, *args, **kwargs)
    310     this += args[0],
    311     args = args[1:]
--> 312 yield from self._plot_args(this, kwargs)

File /gpfs/loomis/project/krishnaswamy_smita/kjm76/conda_envs/flowembed/lib/python3.9/site-packages/matplotlib/axes/_base.py:490, in _process_plot_var_args._plot_args(self, tup, kwargs, return_kwargs)
    488     y = _check_1d(xy[1])
    489 else:
--> 490     x, y = index_of(xy[-1])
    492 if self.axes.xaxis is not None:
    493     self.axes.xaxis.update_units(x)

File /gpfs/loomis/project/krishnaswamy_smita/kjm76/conda_envs/flowembed/lib/python3.9/site-packages/matplotlib/cbook/__init__.py:1614, in index_of(y)
   1612     pass
   1613 try:
-> 1614     y = _check_1d(y)
   1615 except (np.VisibleDeprecationWarning, ValueError):
   1616     # NumPy 1.19 will warn on ragged input, and we can't actually use it.
   1617     pass

File /gpfs/loomis/project/krishnaswamy_smita/kjm76/conda_envs/flowembed/lib/python3.9/site-packages/matplotlib/cbook/__init__.py:1306, in _check_1d(x)
   1304 x = _unpack_to_numpy(x)
   1305 if not hasattr(x, 'shape') or len(x.shape) < 1:
-> 1306     return np.atleast_1d(x)
   1307 else:
   1308     return x

File <__array_function__ internals>:180, in atleast_1d(*args, **kwargs)

File /gpfs/loomis/project/krishnaswamy_smita/kjm76/conda_envs/flowembed/lib/python3.9/site-packages/numpy/core/shape_base.py:65, in atleast_1d(*arys)
     63 res = []
     64 for ary in arys:
---> 65     ary = asanyarray(ary)
     66     if ary.ndim == 0:
     67         result = ary.reshape(1)

File /gpfs/loomis/project/krishnaswamy_smita/kjm76/conda_envs/flowembed/lib/python3.9/site-packages/torch/_tensor.py:757, in Tensor.__array__(self, dtype)
    755     return handle_torch_function(Tensor.__array__, (self,), self, dtype=dtype)
    756 if dtype is None:
--> 757     return self.numpy()
    758 else:
    759     return self.numpy().astype(dtype, copy=False)

TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
{% endraw %} {% raw %}
# Inspect the entry recorded under the 'diffusion' key of the trainer's loss dict.
# NOTE(review): the captured output below is just `100`; confirm whether this
# holds the per-iteration loss history or something else.
BOBO_FET.losses['diffusion']
100
{% endraw %} {% raw %}
import base64
import glob
import os
import re

import ipywidgets as widgets
from PIL import Image


def _natural_sort_key(path):
    """Split `path` into text/integer chunks so 'frame10' sorts after 'frame9'."""
    return [int(chunk) if chunk.isdigit() else chunk for chunk in re.split(r"(\d+)", path)]


# Assemble the saved training frames into an animated GIF and show it inline.
frame_dir = f"visualizations/{BOBO_FET.timestamp}"
# glob.glob returns paths in arbitrary, filesystem-dependent order, so sort
# explicitly or the animation may play out of order.
# NOTE(review): assumes frame filenames encode their position numerically —
# confirm against the code that writes these JPEGs.
frame_paths = sorted(glob.glob(f"{frame_dir}/*.jpg"), key=_natural_sort_key)
if not frame_paths:
    # Without this check frames[0] below would fail with an opaque IndexError.
    raise FileNotFoundError(f"no .jpg frames found in {frame_dir}/")
frames = [Image.open(image) for image in frame_paths]
frame_one = frames[0]
gif_path = f"{BOBO_FET.title}.gif"
# frame_one is the image being saved; append only the remaining frames so the
# first frame is not shown twice.
frame_one.save(gif_path, format="GIF", append_images=frames[1:],
               save_all=True, duration=300, loop=0)
# display in jupyter notebook (embed the GIF as a base64 data URI)
with open(gif_path, "rb") as gif_file:
    b64 = base64.b64encode(gif_file.read()).decode("ascii")
display(widgets.HTML(f'<img src="data:image/gif;base64,{b64}" />'))
{% endraw %}